"""
Visualization script for final paper figures.
Author: Musashi Hinck

From Brandon:
    Simulation 1:
        For each of the three metrics (bias, RMSE, coverage).
        I want a version of the last column (gold standard accuracy = 100%) tracing along the x-axis the accuracy of the surrogate with lines for different sample sizes. 

    Simulation 2:
    

"""

# %% Imports
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

from matplotlib.lines import Line2D
from matplotlib.patches import Patch

matplotlib.rcParams["text.usetex"] = True
sns.set_style("whitegrid")


# %% Reused functions
# Data parsing
def transform_sim_data_to_long(sim_data: pd.DataFrame):
    """
    Reshape wide simulation results into long format.

    Input
      id_vars: n_label, gs_acc, q_acc
      columns: so_bias, dsl_bias, so_rmse, dsl_rmse, so_coverage, dsl_coverage
    Output
      One row per (n_label, gs_acc, q_acc, metric, estimator) with columns
      n_label, gs_acc, q_acc, metric (bias, rmse, coverage),
      estimator (so, dsl), value
    """
    key_cols = ["n_label", "gs_acc", "q_acc"]
    # Wide column names follow "<estimator>_<metric>"; listed metric-major so
    # the melted rows come out in the order so_bias, dsl_bias, so_rmse, ...
    wide_cols = [
        f"{est}_{met}"
        for met in ("bias", "rmse", "coverage")
        for est in ("so", "dsl")
    ]
    long_df = sim_data.melt(
        id_vars=key_cols,
        value_vars=wide_cols,
        var_name="variable",
        value_name="value",
    )
    # The part before the underscore is the estimator, the part after is the
    # metric.
    long_df[["estimator", "metric"]] = long_df["variable"].str.split(
        "_", expand=True
    )
    # Selecting the target columns also discards the helper 'variable' column.
    return long_df[key_cols + ["metric", "estimator", "value"]]


# %% Simulation 1
# From Brandon:
# For each of the three metrics (bias, RMSE, coverage).
# I want a version of the last column (gold standard accuracy = 100%) tracing along the x-axis the accuracy of the surrogate with lines for different sample sizes.

# Load the wide Simulation 1 results and reshape them to long format.
# The path uses the same "../results/final/" prefix as the sim2.csv read and
# both savefig calls in this script, so every cell resolves against the same
# working directory.
# .iloc[:, 1:] drops the first CSV column (presumably a saved index column;
# confirm against the file that writes sim1.csv).
plot_data = transform_sim_data_to_long(
    pd.read_csv("../results/final/sim1.csv").iloc[:, 1:]
)

# %% Visualizations
# Distinct levels of each design variable, in ascending order. They drive the
# colour ramps, legend entries and x-axis ticks of the figure below.
q_acc_vals = sorted(plot_data["q_acc"].unique())
gs_acc_vals = sorted(plot_data["gs_acc"].unique())
n_label_vals = sorted(plot_data["n_label"].unique())

# Relabel metrics for display: "bias" -> "Bias", "coverage" -> "Coverage",
# "rmse" -> "RMSE". The comparison is case-insensitive so that re-running this
# cell leaves "RMSE" untouched instead of turning it into "Rmse" (which would
# then be missing from yaxis_labels).
plot_data.loc[:, "metric"] = plot_data["metric"].apply(
    lambda x: "RMSE" if x.lower() == "rmse" else x.capitalize()
)
# Y-axis label for each displayed metric name. The coverage label is a raw
# string: "\%" is not a valid Python escape sequence, and the backslash must
# reach matplotlib unchanged because text.usetex is enabled.
yaxis_labels = {
    "Bias": "Mean Absolute Bias",
    "RMSE": "Root Mean Squared Error",
    "Coverage": r"Nominal Coverage of 95\% CI",
}

# Seaborn palette name used for each estimator's colour ramp
palette_map = {"dsl": "Reds", "so": "Blues"}

# %% Figure 1
# Create the FacetGrid: one column per metric, each with its own y-axis.
g = sns.FacetGrid(
    plot_data,
    col="metric",
    sharex=True,
    sharey=False,
    legend_out=True,
    height=4,
)

# One colour ramp per estimator, computed once rather than once per axis.
# The ramp has one more colour than there are surrogate-accuracy levels
# because the plotting loop indexes it from 1, skipping the first colour.
estimator_ramps = {
    est_key: sns.color_palette(palette_name, n_colors=len(q_acc_vals) + 1)
    for est_key, palette_name in palette_map.items()
}

# Plot the data. axes_dict maps each level of the `col` variable to its Axes,
# so the metric name does not have to be parsed back out of the facet title.
for metric, ax in g.axes_dict.items():
    for estimator in ["dsl", "so"]:
        cmap = estimator_ramps[estimator]
        for qidx, q_acc in enumerate(q_acc_vals, start=1):
            # Sort by n_label so the line is drawn left to right regardless of
            # the row order in the CSV.
            df = plot_data.loc[
                plot_data["metric"].eq(metric)
                & plot_data["estimator"].eq(estimator)
                & plot_data["q_acc"].eq(q_acc),
                ["n_label", "value"],
            ].sort_values("n_label")
            ax.plot(
                df["n_label"],
                df["value"],
                marker="o" if estimator == "dsl" else "s",
                color=cmap[qidx],
                label=f"{q_acc}",
            )
    ax.yaxis.set_label_text(yaxis_labels[metric])
    # Tick at 0 plus every labelled-sample size present in the data
    ax.set_xticks([0] + n_label_vals)
    if metric == "Coverage":
        # Fixed y-range and a dashed reference line at the 0.95 level
        ax.set_ylim(0.0, 1.05)
        ax.axhline(y=0.95, color="black", linestyle="--")

# Create custom legend: an "SO" subtitle followed by one entry per surrogate
# accuracy, then a "DSL" subtitle followed by its entries. Building the list
# in display order puts the "DSL" subtitle in the right place for any number
# of q_acc levels (a fixed insert index is only right for exactly six).
legend_elements = []

# SO section: invisible patch as the subtitle, then square-marker lines
legend_elements.append(Patch(color="none", label="SO", linewidth=0))
so_cmap = sns.color_palette(palette_map["so"], n_colors=len(q_acc_vals) + 1)
for qidx, q_acc in enumerate(q_acc_vals, start=1):
    legend_elements.append(
        Line2D([0], [0], color=so_cmap[qidx], lw=2, marker="s", label=f"{q_acc}")
    )

# DSL section: invisible patch as the subtitle, then circle-marker lines
legend_elements.append(Patch(color="none", label="DSL", linewidth=0))
dsl_cmap = sns.color_palette(palette_map["dsl"], n_colors=len(q_acc_vals) + 1)
for qidx, q_acc in enumerate(q_acc_vals, start=1):
    legend_elements.append(
        Line2D([0], [0], color=dsl_cmap[qidx], lw=2, marker="o", label=f"{q_acc}")
    )

# Add the custom legend to the plot.
# NOTE(review): no hue is mapped on this grid, so g.add_legend() contributes no
# entries of its own; kept from the original, presumably for its layout side
# effect -- confirm before removing.
g.add_legend()
g.figure.legend(
    handles=legend_elements,
    title="Surrogate Accuracy",
    ncol=1,
    bbox_to_anchor=(1.1, 0.9),
)

# Finalise Figure 1: facet titles show just the metric name, and the x-label
# and overall title are set on the figure rather than on each facet.
sim1_fig = g.figure
g.set_titles(col_template="{col_name}")
sim1_fig.supxlabel("N Labeled Examples", y=-0.02)
sim1_fig.suptitle(
    "Effect of Surrogate Accuracy on Bias, RMSE, and Coverage for DSL and SO Estimators",
    y=1.05,
)

# Display, then write the PDF
sim1_fig.show()
sim1_fig.savefig("../results/final/sim1_figure1.pdf", bbox_inches="tight")

# Interactive inspection: SO coverage rows at surrogate accuracy 0.99.
# The expression is not assigned; it only displays when run as a notebook cell.
so_coverage_at_99 = (
    plot_data["metric"].eq("Coverage")
    & plot_data["estimator"].eq("so")
    & plot_data["q_acc"].eq(0.99)
)
plot_data.loc[so_coverage_at_99]

# %% Simulation 2
# From Brandon:
# Plot 2: Showing that GS errors aren't a huge, huge deal.
# Again I want a variant for each of three metrics.  For gold standard errors, we use the benign process.  For surrogate accuracy we use whatever new malign process you come up with.  Each plot should use one surrogate accuracy level, but I'd love to see options for 75%, 80% and 85% accurate (which seem reasonable to me as surrogate accuracies) and then I'll choose one.  Again each metric-surrogate_accuracy combination should ideally be one plot.  I'll let you figure out how to do that, but I think the key is probably that gold standard accuracy is on the x-axis.  When plotting I'd run it backwards though (so 100% accurate gold standard is on the far left and human error increases as you move right).

# Naoki: fix the surrogate accuracy, and show how DSL changes as the accuracy of the gold standard (gs) changes between levels.
# Read the wide Simulation 2 results, dropping the first CSV column
# (presumably a saved index column; confirm), and reshape to long format.
sim2_wide = pd.read_csv("../results/final/sim2.csv").iloc[:, 1:]
data = transform_sim_data_to_long(sim2_wide)
# Interactive check: row count for each (n_label, gs_acc, q_acc) combination.
# The result is not assigned; it only displays when run as a notebook cell.
data[["n_label", "gs_acc", "q_acc"]].value_counts().reset_index()


# %% Figure actual
# Restrict to surrogate accuracy 0.75 and gold-standard accuracy >= 0.75.
# .copy() makes the relabelling below act on an independent frame, not a view
# of `data`.
plot_data = data.loc[data["q_acc"].eq(0.75) & data["gs_acc"].ge(0.75), :].copy()
# Relabel metrics for display: "bias" -> "Bias", "coverage" -> "Coverage",
# "rmse" -> "RMSE".
plot_data.loc[:, "metric"] = plot_data["metric"].apply(
    lambda x: x.capitalize() if x != "rmse" else "RMSE"
)
# Y-axis label for each displayed metric name. The coverage label is a raw
# string: "\%" is not a valid Python escape sequence, and the backslash must
# reach matplotlib unchanged because text.usetex is enabled.
yaxis_labels = {
    "Bias": "Mean Absolute Bias",
    "RMSE": "Root Mean Squared Error",
    "Coverage": r"Nominal Coverage of 95\% CI",
}


# Distinct levels remaining after the filter, in ascending order
q_acc_vals = sorted(plot_data["q_acc"].unique())
gs_acc_vals = sorted(plot_data["gs_acc"].unique())
n_label_vals = sorted(plot_data["n_label"].unique())

palette_map = {"dsl": "Reds", "so": "Blues"}
# Single colour from the "Blues" palette, used for the SO baseline
blue = sns.color_palette("Blues", n_colors=1)[0]

# Create the FacetGrid: one column per metric, each with its own y-axis.
g = sns.FacetGrid(
    plot_data,
    col="metric",
    sharex=True,
    sharey=False,
    legend_out=True,
    height=4,
)

# Red ramp for the DSL lines, computed once. It has one more colour than there
# are gs_acc levels because the loop indexes it from 1, skipping the first
# colour.
cmap = sns.color_palette("Reds", n_colors=len(gs_acc_vals) + 1)

# Plot the data. axes_dict maps each level of the `col` variable to its Axes,
# so the metric name does not have to be parsed back out of the facet title.
for metric, ax in g.axes_dict.items():
    # DSL: one line per expert (gold-standard) accuracy level
    for gs_idx, gs_acc in enumerate(gs_acc_vals, start=1):
        # Sort by n_label so the line is drawn left to right regardless of
        # the row order in the CSV.
        df = plot_data.loc[
            plot_data["metric"].eq(metric)
            & plot_data["estimator"].eq("dsl")
            & plot_data["gs_acc"].eq(gs_acc),
            ["n_label", "value"],
        ].sort_values("n_label")
        ax.plot(
            df["n_label"],
            df["value"],
            # Every line here is DSL, so the marker is always the circle shown
            # in the legend. It must not depend on an `estimator` variable,
            # which this cell never defines.
            marker="o",
            color=cmap[gs_idx],
            label=f"{gs_acc}",
        )

    # Draw line for SO (it shouldn't vary as a function of gs_acc): take the SO
    # rows at gs_acc == 0.75 and draw their mean as one horizontal baseline.
    df = plot_data.loc[
        plot_data["metric"].eq(metric)
        & plot_data["estimator"].eq("so")
        & plot_data["gs_acc"].eq(0.75),
        ["n_label", "value"],
    ]
    ax.axhline(y=df["value"].mean(), color=blue, linestyle="-.")
    ax.yaxis.set_label_text(yaxis_labels[metric])
    # Tick at 0 plus every labelled-sample size present in the data
    ax.set_xticks([0] + n_label_vals)
    if metric == "Coverage":
        # Fixed y-range and a dashed reference line at the 0.95 level
        ax.set_ylim(0.0, 1.05)
        ax.axhline(y=0.95, color="black", linestyle="--")

# Create custom legend with two sections, each introduced by an invisible
# patch that acts as a subtitle.
legend_elements = []
# First section is expert accuracy
legend_elements.append(Patch(color="none", label="Expert Accuracy", linewidth=0))

# Add DSL legend entries: one circle-marker line per gs_acc level, using the
# same "Reds" ramp and 1-based indexing as the plotted lines
dsl_cmap = sns.color_palette("Reds", n_colors=len(gs_acc_vals) + 1)
for gs_idx, gs_acc in enumerate(gs_acc_vals, start=1):
    legend_elements.append(
        Line2D([0], [0], color=dsl_cmap[gs_idx], lw=2, marker="o", label=f"{gs_acc}")
    )

# Second section is axhlines
legend_elements.append(Patch(color="none", label="Baselines", linewidth=0))
legend_elements.append(
    Line2D([0], [0], color=blue, linestyle="-.", label="SO Baseline")
)
# Raw string: "\%" is not a valid Python escape sequence, and the backslash
# must reach matplotlib unchanged because text.usetex is enabled.
legend_elements.append(
    Line2D([0], [0], color="k", linestyle="--", label=r"95\% Coverage")
)


# Add the custom legend to the plot.
# NOTE(review): no hue is mapped on this grid, so g.add_legend() contributes no
# entries of its own; kept from the original, presumably for its layout side
# effect -- confirm before removing.
g.add_legend()
legend = g.figure.legend(
    handles=legend_elements,
    ncol=1,
    bbox_to_anchor=(1.13, 0.75),
)

# Restyle the invisible-patch entries of the custom legend as subtitles
sns.utils.adjust_legend_subtitles(legend)

# Finalise Figure 2: facet titles show just the metric name, and the x-label
# and overall title are set on the figure rather than on each facet.
sim2_fig = g.figure
g.set_titles(col_template="{col_name}")
sim2_fig.supxlabel("N Labeled Examples", y=-0.05)
sim2_fig.suptitle(
    "Effect of Errors in Expert Labels on DSL Estimator",
    y=1.05,
)

# Display, then write the PDF
sim2_fig.show()
sim2_fig.savefig("../results/final/sim2_figure1.pdf", bbox_inches="tight")


## GRAVEYARD

# %% combined figure
# Retired draft of a combined figure: one facet per metric, with SO and DSL
# lines for every surrogate-accuracy level drawn on the same axes.
# NOTE(review): this cell uses whatever `plot_data` and `q_acc_vals` currently
# hold. When the script is run top to bottom those are the filtered
# Simulation 2 values -- confirm that is intended before reviving this cell.
palette_map = {"dsl": "Reds", "so": "Blues"}
g = sns.FacetGrid(
    plot_data, col="metric", sharex=True, sharey=False, legend_out=True, height=5
)
for ax in g.axes.flatten():
    # Relies on the facet title having the form "<column> = <level>"
    metric = ax.get_title().split(" = ")[1]
    for estimator in ["so", "dsl"]:
        # One more colour than q_acc levels; the ramp is indexed from 1 below
        cmap = sns.color_palette(palette_map[estimator], n_colors=len(q_acc_vals) + 1)
        for qidx, q_acc in enumerate(q_acc_vals, start=1):
            df = plot_data.loc[
                plot_data["metric"].eq(metric)
                & plot_data["estimator"].eq(estimator)
                & plot_data["q_acc"].eq(q_acc),
                ["n_label", "value"],
            ]
            ax.plot(
                df["n_label"],
                df["value"],
                # DSL lines get circle markers; SO lines get none
                marker="o" if estimator == "dsl" else None,
                color=cmap[qidx],
                label=f"{q_acc}",
            )
    # NOTE(review): this legend call targets the LAST axes but runs on every
    # iteration of the per-axis loop, so it is re-created once per facet.
    # It looks like it was meant to run once after the loop -- confirm.
    legend = g.axes[-1, -1].legend(
        title="Surrogate Accuracy\nSO \\bigskip DSL", ncol=2, bbox_to_anchor=(1.05, 0.7)
    )
    # Handles and labels of the current axes; not used afterwards in this cell
    handles, texts = ax.get_legend_handles_labels()
# texts, handles = legend.get_legend_handles_labels()
# Legend: 2 cols, each showing


# g.map_dataframe(sns.lineplot, "n_label", "value", marker="o", aggfunc=None)
# g.map(sns.lineplot, "n_label", "value", marker="o")
# g.add_legend(title="Estimator")
# # g.map(plt.axhline, y=0.95, color="black", linestyle="--")
# g.set_xlabels("")
# g.set_ylabels("")
# g.figure.supxlabel(r"N Labelled")#, y=-0.02)
# g.figure.supylabel(r"Mean Abs Bias")#, x=-0.02)
# g.figure.suptitle("SO vs DSL Bias for Varying Gold-Standard and Surrogate Accuracy", y=1.05)


# %% relplot?
# Retired alternative: let seaborn build the whole figure with relplot.
# One facet per metric, line style by estimator, colour by surrogate accuracy.
g = sns.relplot(
    data=plot_data,
    kind="line",
    x="n_label",
    y="value",
    col="metric",
    hue="q_acc",
    palette="Blues",
    style="estimator",
    style_order=["dsl", "so"],
    markers="o",
    facet_kws={"sharex": True, "sharey": False},
)
g.set_xlabels("N Labelled Samples")
# Put an x tick at every labelled-sample size on each facet
for facet_ax in g.axes.flat:
    facet_ax.set_xticks(n_label_vals)


# %% Different version - separate figure for each metric
def _plot_metric_grid(metric, ylabel, title, reference_line=None):
    """
    Draw a q_acc-by-gs_acc grid of SO vs DSL lines for a single metric.

    Replaces three copy-pasted cells that differed only in the metric filter,
    the y-label, the title, and an optional horizontal reference line.

    Parameters
      metric: metric name to select from the module-level `plot_data`.
        Matched case-insensitively, because the "metric" column is relabelled
        from "bias"/"rmse"/"coverage" to "Bias"/"RMSE"/"Coverage" earlier in
        this script; an exact lowercase match would then select no rows.
      ylabel: text for the figure-level y label.
      title: text for the figure-level title.
      reference_line: if given, y position of a dashed black horizontal line
        drawn on every facet.
    Returns
      The seaborn FacetGrid that was created.
    """
    subset = plot_data.loc[
        plot_data["metric"].str.lower().eq(metric.lower()), :
    ]
    grid = sns.FacetGrid(
        subset,
        col="gs_acc",
        row="q_acc",
        hue="estimator",
        sharex=True,
        sharey="row",
        margin_titles=True,
        height=1,
        aspect=1,
    )
    grid.map(sns.lineplot, "n_label", "value", marker="o")
    if reference_line is not None:
        # Drawn after the lines and before the legend, as in the original cell
        grid.map(plt.axhline, y=reference_line, color="black", linestyle="--")
    grid.add_legend(title="Estimator")
    # Clear per-facet axis labels; figure-level labels are used instead
    grid.set_xlabels("")
    grid.set_ylabels("")
    grid.figure.supxlabel(r"N Labelled")
    grid.figure.supylabel(ylabel)
    grid.figure.suptitle(title, y=1.05)
    return grid


g = _plot_metric_grid(
    "bias",
    r"Mean Abs Bias",
    "SO vs DSL Bias for Varying Gold-Standard and Surrogate Accuracy",
)

# %%
g = _plot_metric_grid(
    "rmse",
    r"Mean RMSE",
    "SO vs DSL RMSE for Varying Gold-Standard and Surrogate Accuracy",
)


# %%
g = _plot_metric_grid(
    "coverage",
    r"Coverage",
    "SO vs DSL Coverage for Varying Gold-Standard and Surrogate Accuracy",
    reference_line=0.95,
)
